# coding: utf-8

from __future__ import print_function

import os
from tensorflow import keras as kr
import torch
from torch import nn
from cnews_loader import read_category, read_vocab
from torch_model import TextCNN, TextRNN
from torch.autograd import Variable
import numpy as np

try:
    bool(type(unicode))
except NameError:
    unicode = str

base_dir = 'cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')


class CnnModel:
    def __init__(self):
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextCNN()
        self.model.load_state_dict(torch.load('model_params.pkl'))

    def predict(self, message):
        # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        y_pred_cls = self.model(data)
        print(y_pred_cls)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]


class RnnModel:
    def __init__(self):
        self.categories, self.cat_to_id = read_category()
        self.words, self.word_to_id = read_vocab(vocab_dir)
        self.model = TextRNN()
        self.model.load_state_dict(torch.load('model_params.pkl'))

    def predict(self, message):
        # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
        content = unicode(message)
        data = [self.word_to_id[x] for x in content if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        y_pred_cls = self.model(data)
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]


if __name__ == '__main__':
    # model = CnnModel()
    model = RnnModel()
    test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机',
                 '热火vs骑士前瞻：皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00',
                 '网上有声音称，接种新冠疫苗后患上了白血病，二者是否有关联？5月27日，国务院联防联控机制召开新闻发布会，介绍刻不容缓，抓实抓细疫情防控有关情况。会上，中国疾控中心免疫规划首席专家王华庆介绍，接种疫苗后出现一些症状和疾病，其和疫苗接种是否有关，相关的判断需要遵守规范流程，并要有依据。如果有怀疑，要报告给接种单位，由接种单位组成多领域专家组，收集相关资料。王华庆表示，疫苗的异常反应要由多学科专家组成的专家组，根据调查的内容进行分析和判断，判断不良反应需要考虑六个方面。首先是时间上的关联；第二是要具备生物学合理性，例如减毒活疫苗一般不给具有免疫缺陷的人接种；第三是关联的强度，通过统计学分析要有显著性差异；第四是异常反应的发生本身也具备规律性，如剂量较大的疫苗导致发热的可能性就更高；第五是关联上的一致性，即打了疫苗的这些人是否出现了类似的症状和疾病，症状是否高于基线水平；第六是关联的特异性，即疫苗是否是唯一的因素。王华庆建议大家如果接种疫苗后出现身体不适，特别是症状较重时要及时就医，如果怀疑与疫苗有关，会有相应人员开展上述调查。（北京日报客户端记者 刘苏雅 实习记者 何蕊）'
                 ]
    for i in test_demo:
        print(i)
        print(model.predict(i))
